{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ONNX 实用工具\n",
    "\n",
    "列出一些常用的 ONNX 工具："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting onnx_helper.py\n"
     ]
    }
   ],
   "source": [
    "%%file onnx_helper.py\n",
    "import numpy as np\n",
    "import onnx\n",
    "from onnx import ModelProto, mapping\n",
    "\n",
    "bg = np.random.MT19937(0)\n",
    "rg = np.random.Generator(bg)\n",
    "\n",
    "def generate_random_inputs(\n",
    "    model: ModelProto, inputs: dict[str, np.ndarray] | None = None\n",
    ") -> dict[str, np.ndarray]:\n",
    "    \"\"\"为ONNX模型生成随机输入数据\n",
    "    \n",
    "    参数:\n",
    "        model: ONNX模型对象(ModelProto 类型)\n",
    "        inputs: 可选参数，预定义的输入字典，键为输入名称，值为numpy数组\n",
    "        \n",
    "    返回:\n",
    "        包含所有输入名称和对应随机值的字典\n",
    "    \"\"\"\n",
    "    input_values = {}\n",
    "    # 遍历模型的所有输入节点\n",
    "    for i in model.graph.input:\n",
    "        # 如果输入已提供且不为None，则直接使用提供的值\n",
    "        if inputs is not None and i.name in inputs and inputs[i.name] is not None:\n",
    "            input_values[i.name] = inputs[i.name]\n",
    "            continue\n",
    "            \n",
    "        # 提取输入张量的形状信息\n",
    "        shape = [dim.dim_value for dim in i.type.tensor_type.shape.dim]\n",
    "\n",
    "        # 生成符合形状和数据类型的随机值\n",
    "        input_values[i.name] = generate_random_value(shape, i.type.tensor_type.elem_type)\n",
    "\n",
    "    return input_values\n",
    "\n",
    "\n",
    "def generate_random_value(shape, elem_type) -> np.ndarray:\n",
    "    \"\"\"根据形状和数据类型生成随机数值数组\n",
    "    \n",
    "    参数:\n",
    "        shape: 数组形状(tuple/list)\n",
    "        elem_type: ONNX张量元素类型\n",
    "        \n",
    "    返回:\n",
    "        符合指定形状和数据类型的随机numpy数组\n",
    "    \"\"\"\n",
    "    # 从ONNX类型映射获取对应的numpy数据类型\n",
    "    if elem_type:\n",
    "        dtype = str(onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[elem_type])\n",
    "    else:\n",
    "        dtype = \"float32\"  # 默认使用float32类型\n",
    "\n",
    "    # 根据不同类型生成随机值\n",
    "    if dtype == \"bool\":\n",
    "        # 生成布尔型随机值(True/False)\n",
    "        random_value = rg.choice(a=[False, True], size=shape)\n",
    "    elif dtype.startswith(\"int\"):\n",
    "        # 生成整型随机值，并确保非零\n",
    "        random_value = rg.integers(low=-63, high=63, size=shape).astype(dtype)\n",
    "        random_value[random_value <= 0] -= 1  # 使所有值非零\n",
    "    else:\n",
    "        # 生成浮点型随机值(标准正态分布)\n",
    "        random_value = rg.standard_normal(size=shape).astype(dtype)\n",
    "\n",
    "    return random_value"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
